# Implements API for Yi-VL in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# This script benefits from https://github.com/xusenlinzy/api-for-open-llm. Thanks for their wonderful works.
import json
import os
import time
import traceback
import uuid
from abc import ABC
from argparse import ArgumentParser
from contextlib import asynccontextmanager
from enum import Enum, IntEnum
from functools import lru_cache, partial
from threading import Thread
from types import MethodType
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)

import anyio
import pydantic
import torch
import uvicorn
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from llava.conversation import conv_templates
from llava.mm_utils import (
    expand2square,
    get_model_name_from_path,
    load_pretrained_model,
    tokenizer_image_token,
)
from llava.model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, key_info
from loguru import logger
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionMessage,
    ChatCompletionMessageParam,
    ChatCompletionToolChoiceOptionParam,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.chat.completion_create_params import ResponseFormat
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from pydantic import BaseModel
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool
from transformers import PreTrainedModel, PreTrainedTokenizer, TextIteratorStreamer


class Role(str, Enum):
    """Roles that may appear in the ``role`` field of OpenAI-style chat messages.

    Being a ``str`` enum, members compare equal to their plain string values
    (e.g. ``"user" == Role.USER``), which is how incoming message dicts are
    checked.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    TOOL = "tool"


class ErrorResponse(BaseModel):
    """JSON body used for error replies (see ``create_error_response``)."""

    # Constant discriminator so clients can recognise error payloads.
    object: str = "error"
    # Human-readable description of the failure.
    message: str
    # Application error code, normally an ``ErrorCode`` member.
    code: int


class ErrorCode(IntEnum):
    """
    Application error codes placed in the ``code`` field of ``ErrorResponse``.

    Modelled on https://platform.openai.com/docs/guides/error-codes/api-errors ;
    the leading three digits appear to mirror an HTTP status class (400, 401,
    403, 429, 500). NOTE(review): this is a naming convention only — the HTTP
    status actually sent is chosen by whoever builds the response.
    """

    # 400xx: malformed request
    VALIDATION_TYPE_ERROR = 40001

    # 401xx: authentication / authorisation
    INVALID_AUTH_KEY = 40101
    INCORRECT_AUTH_KEY = 40102
    NO_PERMISSION = 40103

    # 403xx: model or parameter problems
    INVALID_MODEL = 40301
    PARAM_OUT_OF_RANGE = 40302
    CONTEXT_OVERFLOW = 40303

    # 429xx: rate / capacity limits
    RATE_LIMIT = 42901
    QUOTA_EXCEEDED = 42902
    ENGINE_OVERLOADED = 42903

    # 500xx: server-side failures
    INTERNAL_ERROR = 50001
    CUDA_OUT_OF_MEMORY = 50002
    GRADIO_REQUEST_ERROR = 50003
    GRADIO_STREAM_UNKNOWN_ERROR = 50004
    CONTROLLER_NO_WORKER = 50005
    CONTROLLER_WORKER_TIMEOUT = 50006


class ChatCompletionCreateParams(BaseModel):
    """
    Request body for ``POST /v1/chat/completions``.

    Mirrors OpenAI's chat-completion parameters (the field notes below are
    taken from OpenAI's API reference) plus a few extra sampling options.
    Not every field is honoured by this server: the generation code in this
    file reads ``temperature``, ``top_p``, ``max_tokens``, ``stop`` and
    ``stop_token_ids``; ``functions`` only adds an ``"Observation:"`` stop
    string; ``stream`` selects SSE output; ``n`` is range-checked, but a single
    choice is always returned. The remaining fields are accepted for client
    compatibility.
    """

    messages: List[ChatCompletionMessageParam]
    """A list of messages comprising the conversation so far.
    [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
    """

    model: str
    """ID of the model to use.
    See the
    [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility)
    table for details on which models work with the Chat API.
    """

    frequency_penalty: Optional[float] = 0.0
    """Number between -2.0 and 2.0.
    Positive values penalize new tokens based on their existing frequency in the
    text so far, decreasing the model's likelihood to repeat the same line verbatim.
    [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
    """

    function_call: Optional[FunctionCall] = None
    """Deprecated in favor of `tool_choice`.
    Controls which (if any) function is called by the model. `none` means the model
    will not call a function and instead generates a message. `auto` means the model
    can pick between generating a message or calling a function. Specifying a
    particular function via `{"name": "my_function"}` forces the model to call that
    function.
    `none` is the default when no functions are present. `auto`` is the default if
    functions are present.
    """

    functions: Optional[List] = None
    """Deprecated in favor of `tools`.
    A list of functions the model may generate JSON inputs for.
    """

    logit_bias: Optional[Dict[str, int]] = None
    """Modify the likelihood of specified tokens appearing in the completion.
    Accepts a JSON object that maps tokens (specified by their token ID in the
    tokenizer) to an associated bias value from -100 to 100. Mathematically, the
    bias is added to the logits generated by the model prior to sampling. The exact
    effect will vary per model, but values between -1 and 1 should decrease or
    increase likelihood of selection; values like -100 or 100 should result in a ban
    or exclusive selection of the relevant token.
    """

    max_tokens: Optional[int] = None
    """The maximum number of [tokens](/tokenizer) to generate in the chat completion.
    The total length of input tokens and generated tokens is limited by the model's
    context length.
    [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
    for counting tokens.
    """

    n: Optional[int] = 1
    """How many chat completion choices to generate for each input message."""

    presence_penalty: Optional[float] = 0.0
    """Number between -2.0 and 2.0.
    Positive values penalize new tokens based on whether they appear in the text so
    far, increasing the model's likelihood to talk about new topics.
    [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
    """

    response_format: Optional[ResponseFormat] = None
    """An object specifying the format that the model must output.
    Used to enable JSON mode.
    """

    seed: Optional[int] = None
    """This feature is in Beta.
    If specified, our system will make a best effort to sample deterministically,
    such that repeated requests with the same `seed` and parameters should return
    the same result. Determinism is not guaranteed, and you should refer to the
    `system_fingerprint` response parameter to monitor changes in the backend.
    """

    stop: Optional[Union[str, List[str]]] = None
    """Up to 4 sequences where the API will stop generating further tokens."""

    # NOTE: the server-side default (0.9) differs from OpenAI's default of 1.
    temperature: Optional[float] = 0.9
    """What sampling temperature to use, between 0 and 2.
    Higher values like 0.8 will make the output more random, while lower values like
    0.2 will make it more focused and deterministic.
    We generally recommend altering this or `top_p` but not both.
    """

    tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None
    """
    Controls which (if any) function is called by the model. `none` means the model
    will not call a function and instead generates a message. `auto` means the model
    can pick between generating a message or calling a function. Specifying a
    particular function via
    `{"type: "function", "function": {"name": "my_function"}}` forces the model to
    call that function.
    `none` is the default when no functions are present. `auto` is the default if
    functions are present.
    """

    tools: Optional[List] = None
    """A list of tools the model may call.
    Currently, only functions are supported as a tool. Use this to provide a list of
    functions the model may generate JSON inputs for.
    """

    top_p: Optional[float] = 1.0
    """
    An alternative to sampling with temperature, called nucleus sampling, where the
    model considers the results of the tokens with top_p probability mass. So 0.1
    means only the tokens comprising the top 10% probability mass are considered.
    We generally recommend altering this or `temperature` but not both.
    """

    user: Optional[str] = None
    """
    A unique identifier representing your end-user, which can help OpenAI to monitor
    and detect abuse.
    [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    """

    stream: Optional[bool] = False
    """If set, partial message deltas will be sent, like in ChatGPT.
    Tokens will be sent as data-only
    [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
    as they become available, with the stream terminated by a `data: [DONE]`
    message.
    [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
    """

    # Additional (non-OpenAI) parameters.
    # Paper for repetition penalty: https://arxiv.org/pdf/1909.05858.pdf
    repetition_penalty: Optional[float] = 1.03
    """The parameter for repetition penalty. 1.0 means no penalty.
    See[this paper](https://arxiv.org / pdf / 1909.05858.pdf) for more details.
    """

    # Paper for typical decoding: https://arxiv.org/abs/2202.00666
    typical_p: Optional[float] = None
    """Typical Decoding mass.
    See[Typical Decoding for Natural Language Generation](https://arxiv.org / abs / 2202.00666) for more information
    """

    # Paper for watermarking: https://arxiv.org/abs/2301.10226
    watermark: Optional[bool] = False
    """Watermarking with [A Watermark for Large Language Models](https://arxiv.org / abs / 2301.10226)
    """

    best_of: Optional[int] = 1

    ignore_eos: Optional[bool] = False

    use_beam_search: Optional[bool] = False

    # Token ids that end generation; the template's ids are merged in per request.
    stop_token_ids: Optional[List[int]] = None

    skip_special_tokens: Optional[bool] = True

    spaces_between_special_tokens: Optional[bool] = True

    min_p: Optional[float] = 0.0


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    """
    FastAPI lifespan hook.

    Nothing is set up at startup. At shutdown, if CUDA is available, cached
    allocator memory is released and CUDA IPC handles are collected.
    """
    yield
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


# ASGI application; ``lifespan`` frees cached CUDA memory on shutdown.
app = FastAPI(lifespan=lifespan)

# Fully permissive CORS: any origin, method and header is accepted, with
# credentials allowed. NOTE(review): tighten allow_origins if this server is
# reachable from untrusted browsers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatCompletionCreateParams, raw_request: Request
):
    """
    OpenAI-compatible ``POST /v1/chat/completions`` handler.

    Rejects empty conversations and conversations ending with an assistant
    turn (HTTP 400), merges the template's stop settings into the request and
    defaults ``max_tokens`` to 1024. The engine runs in a worker thread.

    Returns:
        The engine's non-streaming result (a ``ChatCompletion`` or an error
        ``JSONResponse``), or, for streaming requests, an
        ``EventSourceResponse`` that relays the engine's chunks.
    """
    if not request.messages or request.messages[-1]["role"] == Role.ASSISTANT:
        raise HTTPException(status_code=400, detail="Invalid request")

    request = await handle_request(request, engine.template.stop)
    request.max_tokens = request.max_tokens or 1024

    params = model_dump(request, exclude={"messages"})
    params.update(prompt_or_messages=request.messages, echo=False)
    logger.debug(f"==== request ====\n{params}")

    result = await run_in_threadpool(engine.create_chat_completion, params)
    if not isinstance(result, Iterator):
        return result

    # Advance the stream once up front: anything that fails before the first
    # item surfaces here as a normal error instead of inside the SSE stream.
    head = await run_in_threadpool(next, result)

    def _replay() -> Iterator:
        # Re-attach the already-consumed first item in front of the rest.
        yield head
        yield from result

    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    publisher = partial(
        get_event_publisher,
        request=raw_request,
        inner_send_chan=send_chan,
        iterator=_replay(),
    )
    return EventSourceResponse(recv_chan, data_sender_callable=publisher)


# Prefix of the error text reported when generation raises (see _generate).
server_error_msg: str = (
    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)


@torch.inference_mode()
def generate_stream(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
):
    """
    Stream a completion for one (optionally image-conditioned) prompt.

    ``model.generate`` runs in a background thread feeding a
    ``TextIteratorStreamer``. This generator turns the decoded text into dicts
    with ``id``, ``object``, ``created``, ``model``, ``delta``, ``text``,
    ``logprobs``, ``finish_reason`` and ``usage`` keys. The last dict always has
    ``delta == ""`` and ``finish_reason == "stop"``.

    Args:
        model: Model whose ``generate`` accepts ``input_ids`` and ``images``.
        tokenizer: Tokenizer used for streaming decode and token counting.
        params: Request parameters. Reads ``inputs`` (prompt token ids),
            ``image_tensor``, ``has_image``, ``model``, ``temperature``,
            ``top_p``, ``top_k``, ``max_tokens``, ``stop_token_ids`` and
            ``stop``. Missing or ``None`` values fall back to defaults; the
            ``stop_token_ids`` list in *params* is not modified.
    """

    def _param(name: str, default: Any) -> Any:
        # Optional request fields may be explicitly null (e.g. "top_p": null).
        value = params.get(name)
        return default if value is None else value

    input_ids = params.get("inputs")
    image_tensor = params.get("image_tensor")
    has_image = params.get("has_image", False)

    model_name = _param("model", "llm")
    temperature = float(_param("temperature", 1.0))
    top_p = float(_param("top_p", 1.0))
    top_k = int(_param("top_k", 40))
    max_new_tokens = int(_param("max_tokens", 1024))

    # Copy so the caller's list is not mutated.
    stop_token_ids = list(params.get("stop_token_ids") or [])
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is not None and eos_token_id not in stop_token_ids:
        stop_token_ids.append(eos_token_id)
    stop_strings = params.get("stop") or []
    if isinstance(stop_strings, str):
        stop_strings = [stop_strings]

    input_echo_len = len(input_ids)
    device = model.device
    # as_tensor: the inputs may already be tensors/arrays; avoids a copy warning.
    input_ids = torch.as_tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
    if has_image:
        image_tensor = torch.as_tensor(
            image_tensor, dtype=torch.bfloat16, device=device
        ).unsqueeze(0)
    generation_kwargs = dict(
        input_ids=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
    )
    if stop_token_ids:
        # Stop on any of these ids. The streamer below skips special tokens, so
        # tokens such as <|im_end|> (presumably special for Yi) can never be
        # caught by stop-string matching on the decoded text.
        generation_kwargs["eos_token_id"] = stop_token_ids
    if temperature <= 1e-5:
        # Effectively zero temperature: decode greedily.
        generation_kwargs["do_sample"] = False
        generation_kwargs.pop("top_k")

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs["streamer"] = streamer

    # Make sure the standard transformers generate loop is used.
    if "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)

    # NOTE: the thread is not cancelled when a stop string is found below; it
    # runs until a stop token id or max_new_tokens is reached.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())

    def _chunk(
        delta: str, text: str, finish_reason: Optional[str], completion_tokens: int
    ) -> Dict[str, Any]:
        # One output record in the shape consumed by DefaultEngine.
        return {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "delta": delta,
            "text": text,
            "logprobs": None,
            "finish_reason": finish_reason,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": completion_tokens,
                "total_tokens": input_echo_len + completion_tokens,
            },
        }

    raw_text = ""  # everything decoded so far, stop strings not applied
    generated_text = ""  # raw_text after stop-string handling (reported text)
    previous_text = ""  # part of generated_text already emitted
    num_chunks = 0
    stop_found = False
    for new_text in streamer:
        num_chunks += 1
        raw_text += new_text
        # Always recompute from the untouched raw text: trimming the
        # accumulator itself would permanently drop characters that merely
        # looked like the beginning of a stop string.
        generated_text, stop_found = apply_stopping_strings(raw_text, stop_strings)

        # Hold back output ending in an incomplete multi-byte character.
        if generated_text and generated_text[-1] != "�":
            delta_text = generated_text[len(previous_text) :]
            previous_text = generated_text
            # Streamer chunks are not tokens; this intermediate count is approximate.
            yield _chunk(delta_text, generated_text, None, num_chunks)

        if stop_found:
            break

    if not stop_found:
        # The stream ended without a stop string, so any tail held back as a
        # possible stop-string prefix (or incomplete character) is final now.
        generated_text = raw_text
        if len(generated_text) > len(previous_text):
            yield _chunk(
                generated_text[len(previous_text) :], generated_text, None, num_chunks
            )

    # Final usage: re-tokenize the returned text rather than counting chunks.
    completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False))
    yield _chunk("", generated_text, "stop", completion_tokens)


class DefaultEngine(ABC):
    """Model engine for Yi-VL built on native ``transformers`` generation."""

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        image_processor,
        device: Union[str, torch.device],
        model_name: str,
    ) -> None:
        """
        Initialize the engine.

        Args:
            model (PreTrainedModel): The pre-trained model.
            tokenizer (PreTrainedTokenizer): The tokenizer for the model.
            image_processor: Processor whose ``preprocess`` returns
                ``pixel_values`` and which exposes ``image_mean``.
            device (Union[str, torch.device]): Fallback device, used only when
                the model has no ``device`` attribute.
            model_name (str): The name of the model (stored lower-cased).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.device = model.device if hasattr(model, "device") else torch.device(device)

        self.model_name = model_name.lower()
        self.template = YiAITemplate()

        self._prepare_for_generate()

    def _prepare_for_generate(self) -> None:
        """
        Prepare the engine for text generation.

        Selects ``generate_stream`` as the streaming function and reads the
        model's context length from its config.
        """
        self.generate_stream_func = generate_stream

        self.context_len = get_context_length(self.model.config)

    def convert_to_inputs(
        self, prompt_or_messages: Union[List[ChatCompletionMessageParam], str]
    ) -> Tuple[Any, Optional[Any], str]:
        """
        Build model inputs from OpenAI-style chat messages.

        User and assistant messages are appended in order to a copy of the
        ``mm_default`` conversation template, followed by an open assistant
        turn. A message's ``content`` may be a string or a list of parts:
        ``text`` parts are joined with newlines, and the first ``image_url``
        found in the conversation is used as the single image, with
        ``DEFAULT_IMAGE_TOKEN`` prepended to the message that carries it.
        Other roles (system/tool/function) are skipped. A bare string is
        treated as one user message.

        Args:
            prompt_or_messages: The messages (or a plain prompt string).

        Returns:
            ``(input_ids, image_tensor, stop_str)``: the prompt ids from
            ``tokenizer_image_token``, the preprocessed image (``None`` if the
            conversation has no image) and the template separator.

        Raises:
            ValueError: if an http(s) image cannot be fetched. Other image
                loading errors propagate from PIL.
        """
        if isinstance(prompt_or_messages, str):
            messages = [{"role": "user", "content": prompt_or_messages}]
        else:
            messages = prompt_or_messages

        conv = conv_templates["mm_default"].copy()
        stop_str = conv.sep
        image_url = None
        for message in messages:
            role = message.get("role")
            if role not in (Role.USER, Role.ASSISTANT):
                # NOTE(review): system prompts come from the template itself
                # (assumed); client system/tool messages are not representable.
                continue
            text, url = self._split_content(message.get("content"))
            if url and image_url is None:
                # Only one image per conversation is supported.
                image_url = url
                text = DEFAULT_IMAGE_TOKEN + "\n" + text
            conv_role = conv.roles[0] if role == Role.USER else conv.roles[1]
            conv.append_message(conv_role, text)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX)

        image_tensor = None
        if image_url:
            image = self._load_image(image_url)
            if getattr(self.model.config, "image_aspect_ratio", None) == "pad":
                image = expand2square(
                    image, tuple(int(x * 255) for x in self.image_processor.image_mean)
                )
            image_tensor = self.image_processor.preprocess(image)["pixel_values"][0]

        return input_ids, image_tensor, stop_str

    @staticmethod
    def _split_content(content: Any) -> Tuple[str, Optional[str]]:
        """Return (joined text parts, first image URL or None) of a message content."""
        if content is None:
            return "", None
        if isinstance(content, str):
            return content, None
        texts: List[str] = []
        image_url: Optional[str] = None
        for part in content:
            if part.get("text") is not None:
                texts.append(part["text"])
            image = part.get("image_url")
            if image_url is None and image:
                url = image.get("url") if isinstance(image, dict) else image
                if url:
                    image_url = url
        return "\n".join(texts), image_url

    @staticmethod
    def _load_image(image_file: str) -> "Image.Image":
        """Open an RGB image from an http(s) URL, a base64 ``data:`` URL or a local path."""
        from io import BytesIO

        if image_file.startswith("http"):  # url
            import requests

            response = requests.get(image_file, timeout=30)
            if response.status_code != 200:
                raise ValueError(
                    f"Failed to fetch image {image_file!r}: HTTP {response.status_code}"
                )
            image = Image.open(BytesIO(response.content))
        elif image_file.startswith("data:"):  # data:image/...;base64,<payload>
            import base64

            _, _, payload = image_file.partition(",")
            image = Image.open(BytesIO(base64.b64decode(payload)))
        else:
            # SECURITY(review): this opens an arbitrary server-side path taken
            # from the request body; restrict it if the API is publicly exposed.
            image = Image.open(image_file)  # local path
        # Normalise palette/RGBA/grayscale inputs for the image processor.
        return image.convert("RGB")

    def _generate(self, params: Dict[str, Any]) -> Iterator[dict]:
        """
        Generate text for one request.

        Converts ``params["prompt_or_messages"]`` into model inputs, stores them
        in *params* (``inputs``, ``image_tensor``, ``has_image``), adds the
        template separator to ``params["stop"]`` and streams from
        ``generate_stream_func``.

        Args:
            params (Dict[str, Any]): Request parameters (updated in place).

        Yields:
            dict: Generator outputs with ``error_code`` 0, or, if generation
            raises CUDA OOM / ValueError / RuntimeError, one dict with ``text``
            and a non-zero ``error_code``.
        """
        prompt_or_messages = params.get("prompt_or_messages")
        input_ids, image_tensor, stop_str = self.convert_to_inputs(prompt_or_messages)

        stop = params.get("stop") or []
        stop = [stop] if isinstance(stop, str) else list(stop)
        # An empty separator would match at position 0 and erase all output.
        if stop_str and stop_str not in stop:
            stop.append(stop_str)

        params.update(
            dict(
                inputs=input_ids,
                image_tensor=image_tensor,
                has_image=image_tensor is not None,
                stop=stop,
            )
        )

        try:
            for output in self.generate_stream_func(self.model, self.tokenizer, params):
                output["error_code"] = 0
                yield output

        except torch.cuda.OutOfMemoryError as e:
            yield {
                "text": f"{server_error_msg}\n\n({e})",
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
            }

        except (ValueError, RuntimeError) as e:
            traceback.print_exc()
            yield {
                "text": f"{server_error_msg}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }

    @staticmethod
    def _make_chunk(
        _id: Any, created: Any, model: Any, delta: ChoiceDelta, finish_reason: Optional[str]
    ) -> ChatCompletionChunk:
        """Wrap one delta in a single-choice ``ChatCompletionChunk``."""
        choice = ChunkChoice(
            index=0,
            delta=delta,
            finish_reason=finish_reason,
            logprobs=None,
        )
        return ChatCompletionChunk(
            id=f"chat{_id}",
            choices=[choice],
            created=created,
            model=model,
            object="chat.completion.chunk",
        )

    def _create_chat_completion_stream(
        self, params: Dict[str, Any]
    ) -> Iterator[ChatCompletionChunk]:
        """
        Create a chat completion stream.

        Emits an opening chunk with the assistant role, one chunk per non-empty
        text delta, and a closing chunk with ``finish_reason="stop"``. If the
        generator reports an error, the raw error dict is yielded and the
        stream ends.

        Args:
            params (Dict[str, Any]): The parameters for generating the chat completion.

        Yields:
            ChatCompletionChunk (or an error dict).
        """
        _id, _created, _model = None, None, None
        for i, output in enumerate(self._generate(params)):
            if output["error_code"] != 0:
                yield output
                return

            _id, _created, _model = output["id"], output["created"], output["model"]
            if i == 0:
                yield self._make_chunk(
                    _id, _created, _model, ChoiceDelta(role="assistant", content=""), None
                )

            finish_reason = output["finish_reason"]
            # Empty deltas (including the generator's final record) are skipped;
            # the closing chunk below carries the finish reason.
            if len(output["delta"]) == 0 and finish_reason != "function_call":
                continue

            yield self._make_chunk(
                _id, _created, _model, ChoiceDelta(content=output["delta"]), finish_reason
            )

        yield self._make_chunk(_id, _created, _model, ChoiceDelta(), "stop")

    def _create_chat_completion(
        self, params: Dict[str, Any]
    ) -> Union[ChatCompletion, JSONResponse]:
        """
        Create a non-streaming chat completion.

        Args:
            params (Dict[str, Any]): The parameters for generating the chat completion.

        Returns:
            ChatCompletion built from the last generator output, or an error
            JSONResponse if generation failed or produced nothing.
        """
        last_output = None
        for output in self._generate(params):
            last_output = output

        if last_output is None:
            return create_error_response(
                ErrorCode.INTERNAL_ERROR, "The model returned no output."
            )
        if last_output["error_code"] != 0:
            return create_error_response(last_output["error_code"], last_output["text"])

        message = ChatCompletionMessage(
            role="assistant",
            content=last_output["text"].strip(),
        )
        choice = Choice(
            index=0,
            message=message,
            finish_reason="stop",
            logprobs=None,
        )
        usage = model_parse(CompletionUsage, last_output["usage"])
        return ChatCompletion(
            id=f"chat{last_output['id']}",
            choices=[choice],
            created=last_output["created"],
            model=last_output["model"],
            object="chat.completion",
            usage=usage,
        )

    def create_chat_completion(
        self,
        params: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[Iterator[ChatCompletionChunk], ChatCompletion]:
        """Dispatch to the streaming or non-streaming path based on ``params["stream"]``."""
        params = params or {}
        params.update(kwargs)
        return (
            self._create_chat_completion_stream(params)
            if params.get("stream", False)
            else self._create_chat_completion(params)
        )

    @property
    def stop(self):
        """
        Gets the stop property of the prompt template.

        Returns:
            The template's ``stop`` attribute, or None if it does not exist.
        """
        return self.template.stop if hasattr(self.template, "stop") else None


class YiAITemplate(ABC):
    """
    Prompt and stop-token settings for Yi chat models.

    Reference: https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/tokenizer_config.json
    """

    name = "yi"
    system_prompt: Optional[str] = ""
    allow_models = ["yi"]
    # Stop settings merged into each request by the chat endpoint.
    stop = {
        "strings": ["<|endoftext|>", "<|im_end|>"],
        "token_ids": [
            2,
            6,
            7,
            8,
        ],  # "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"
    }
    function_call_available: Optional[bool] = False

    def apply_chat_template(
        self,
        conversation: List[ChatCompletionMessageParam],
        add_generation_prompt: bool = True,
    ) -> str:
        """
        Render a list of ``{"role": ..., "content": ...}`` dicts into a ChatML prompt.

        Args:
            conversation (List[ChatCompletionMessageParam]): The chat history;
                each ``content`` must be a string.
            add_generation_prompt (bool, *optional*): Whether to end the prompt
                with ``<|im_start|>assistant\\n`` so the model answers next.

        Returns:
            `str`: A prompt, ready to pass to the tokenizer.

        NOTE: ``system_prompt`` is passed to the renderer, but ``template`` does
        not reference it; include system messages in *conversation* instead.
        """
        # Compilation function uses a cache to avoid recompiling the same template
        compiled_template = _compile_jinja_template(self.template)
        return compiled_template.render(
            messages=conversation,
            add_generation_prompt=add_generation_prompt,
            system_prompt=self.system_prompt,
        )

    @property
    def template(self) -> str:
        """Jinja source of the ChatML (``<|im_start|>`` / ``<|im_end|>``) template."""
        return (
            "{% for message in messages %}"
            "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            "{{ '<|im_start|>assistant\\n' }}"
            "{% endif %}"
        )

    def postprocess_messages(
        self, messages: List[ChatCompletionMessageParam]
    ) -> List[Dict[str, Any]]:
        """Return *messages* unchanged (no Yi-specific rewriting is needed)."""
        return messages

    def parse_assistant_response(
        self, output: str
    ) -> Tuple[str, Optional[Union[str, Dict[str, Any]]]]:
        """Return ``(output, None)``: this template does no function-call parsing."""
        return output, None


@lru_cache
def _compile_jinja_template(chat_template: str):
    """
    Compile a Jinja template from a string.
    Args:
        chat_template (str): The string representation of the Jinja template.
    Returns:
        jinja2.Template: The compiled Jinja template.
    Examples:
        >>> template_string = "Hello, {{ name }}!"
        >>> template = _compile_jinja_template(template_string)
    """
    try:
        from jinja2.exceptions import TemplateError
        from jinja2.sandbox import ImmutableSandboxedEnvironment
    except ImportError:
        raise ImportError("apply_chat_template requires jinja2 to be installed.")

    def raise_exception(message):
        raise TemplateError(message)

    jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    jinja_env.globals["raise_exception"] = raise_exception
    return jinja_env.from_string(chat_template)


async def handle_request(
    request: Union[ChatCompletionCreateParams], stop: Optional[Dict[str, Any]] = None
) -> Union[Union[ChatCompletionCreateParams], JSONResponse]:
    """
    Validate a chat request and merge server-side stop settings into it.

    Args:
        request: The parsed request. Its ``stop`` and ``stop_token_ids`` fields
            are rewritten in place.
        stop: Optional dict with ``"strings"`` and ``"token_ids"`` lists (e.g. a
            template's ``stop`` attribute) to merge into the request.

    Returns:
        The same request object. ``stop`` is now a de-duplicated list of strings
        (server strings first, then the client's, plus ``"Observation:"`` when
        ``functions`` is set) and ``stop_token_ids`` a de-duplicated list of ids.

    Raises:
        HTTPException: 400, if ``check_requests`` rejects a parameter. The
            detail is the error payload built by ``check_requests``.
    """
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        # check_requests returns a JSONResponse, which is not an exception;
        # raising it directly failed with a TypeError (opaque HTTP 500).
        raise HTTPException(status_code=400, detail=json.loads(error_check_ret.body))

    # stop settings
    _stop, _stop_token_ids = [], []
    if stop is not None:
        _stop_token_ids = list(stop.get("token_ids", []))
        _stop = list(stop.get("strings", []))

    request_stop = request.stop or []
    if isinstance(request_stop, str):
        request_stop = [request_stop]
    request_stop = list(request_stop)

    if request.functions:
        request_stop.append("Observation:")

    # dict.fromkeys de-duplicates while keeping a deterministic order; set()
    # ordering of strings changes between runs under hash randomization.
    request.stop = list(dict.fromkeys(_stop + request_stop))
    request.stop_token_ids = list(
        dict.fromkeys(_stop_token_ids + list(request.stop_token_ids or []))
    )

    return request


def check_requests(
    request: Union[ChatCompletionCreateParams],
) -> Optional[JSONResponse]:
    """
    Range-check the numeric parameters and the type of ``stop`` on a request.

    Returns:
        ``None`` if the request is valid; otherwise an error ``JSONResponse``
        (via ``create_error_response``) with ``ErrorCode.PARAM_OUT_OF_RANGE``
        and a message naming the offending parameter.
    """
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        # Previously this message wrongly named 'temperature'.
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    if request.stop is None or isinstance(request.stop, (str, list)):
        return None
    return create_error_response(
        ErrorCode.PARAM_OUT_OF_RANGE,
        f"{request.stop} is not valid under any of the given schemas - 'stop'",
    )


def create_error_response(code: int, message: str) -> JSONResponse:
    """
    Build an error reply.

    The body is an ``ErrorResponse`` (``object``, ``message``, ``code``)
    serialized with ``model_dump``. The HTTP status is always 500, whatever
    *code* is.
    """
    payload = ErrorResponse(message=message, code=code)
    body = model_dump(payload)
    return JSONResponse(body, status_code=500)


async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Union[Iterator, AsyncIterator],
):
    """
    Pump items from *iterator* into the channel read by ``EventSourceResponse``.

    Each item is serialized and sent as ``{"data": ...}``: pydantic models go
    through ``model_json``, dicts through ``json.dumps``, and anything else is
    sent as-is. After the iterator is exhausted, a final ``[DONE]`` event is
    sent. The iterator is advanced via ``iterate_in_threadpool``, so blocking
    ``next()`` calls run in worker threads.

    If the client disconnects (checked after every send), a cancellation
    exception is raised, logged and re-raised. The ``async with`` closes the
    send channel on every exit path.

    NOTE(review): ``iterate_in_threadpool`` iterates a synchronous iterator;
    passing an ``AsyncIterator`` here is probably unsupported — confirm.
    """
    async with inner_send_chan:
        try:
            async for chunk in iterate_in_threadpool(iterator):
                if isinstance(chunk, BaseModel):
                    chunk = model_json(chunk)
                elif isinstance(chunk, dict):
                    chunk = json.dumps(chunk, ensure_ascii=False)

                await inner_send_chan.send(dict(data=chunk))

                # Client went away: abort through the cancellation handler below.
                if await request.is_disconnected():
                    raise anyio.get_cancelled_exc_class()()

            await inner_send_chan.send(dict(data="[DONE]"))
        except anyio.get_cancelled_exc_class() as e:
            logger.info("disconnected")
            # Shielded scope (at most 1 s) so the logging can run while the
            # task is being cancelled; the cancellation is then re-raised.
            with anyio.move_on_after(1, shield=True):
                logger.info(
                    f"Disconnected from client (via refresh/close) {request.client}"
                )
                raise e


def create_generate_model(args):
    """
    Load the model at ``args.model_path`` and wrap it in a ``DefaultEngine``.

    The expanded model path is recorded in ``key_info["model_path"]`` before
    loading. The engine's fallback device is ``"cuda"`` and its name is
    ``args.model_name``.
    """
    expanded_path = os.path.expanduser(args.model_path)
    key_info["model_path"] = expanded_path
    # Return value is unused; the call is kept exactly as before.
    get_model_name_from_path(expanded_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(expanded_path)

    logger.info("Using default engine")

    engine = DefaultEngine(
        model,
        tokenizer,
        image_processor,
        "cuda",
        model_name=args.model_name,
    )
    return engine


# --------------- Pydantic v2 compatibility ---------------

# True when the installed pydantic is a 2.x release (selects v2 method names).
PYDANTIC_V2: bool = pydantic.VERSION.startswith("2.")


def model_json(model: pydantic.BaseModel, **kwargs) -> str:
    """Serialize *model* to a JSON string with the installed pydantic's API."""
    serialize = model.model_dump_json if PYDANTIC_V2 else model.json  # type: ignore
    return serialize(**kwargs)


def model_dump(model: pydantic.BaseModel, **kwargs) -> Dict[str, Any]:
    """Convert *model* to a plain dict with the installed pydantic's API."""
    if not PYDANTIC_V2:
        return cast("dict[str, Any]", model.dict(**kwargs))
    return model.model_dump(**kwargs)


def model_parse(model: Type[pydantic.BaseModel], data: Any) -> pydantic.BaseModel:
    """Validate *data* into an instance of the pydantic class *model*."""
    validate = model.model_validate if PYDANTIC_V2 else model.parse_obj  # pyright: ignore[reportDeprecated]
    return validate(data)


# Models don't use the same configuration key for determining the maximum
# sequence length.  Store them here so we can sanely check them.
# NOTE: The ordering here is important.  Some models have two of these, and we
# have a preference for which value gets used.
SEQUENCE_LENGTH_KEYS: List[str] = [
    "max_sequence_length",  # highest priority
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",  # lowest priority
]


def get_context_length(
    config, sequence_length_keys: Optional[List[str]] = None
) -> int:
    """
    Get the context length of a model from a huggingface model config.

    The first attribute in *sequence_length_keys* (default:
    ``SEQUENCE_LENGTH_KEYS``, in priority order) that is set on *config* is
    used. It is multiplied by ``config.rope_scaling["factor"]`` when RoPE
    scaling is configured with a factor.

    Args:
        config: Model config object (attributes are read with ``getattr``).
        sequence_length_keys: Optional override of the attribute names to try.

    Returns:
        The context length, or 2048 if none of the keys is set.
    """
    keys = SEQUENCE_LENGTH_KEYS if sequence_length_keys is None else sequence_length_keys

    rope_scaling = getattr(config, "rope_scaling", None)
    rope_scaling_factor = 1
    # Some rope_scaling dicts have no "factor" (or factor=None); the previous
    # unconditional ["factor"] lookup crashed on them.
    if isinstance(rope_scaling, dict) and rope_scaling.get("factor") is not None:
        rope_scaling_factor = rope_scaling["factor"]

    for key in keys:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048


def apply_stopping_strings(reply: str, stop_strings: List[str]) -> Tuple[str, bool]:
    """
    Truncate *reply* at the earliest stop string, or hold back a partial one.

    If any non-empty stop string occurs in *reply*, the reply is cut at the
    earliest occurrence (regardless of the order of *stop_strings*), and
    ``stop_found`` is True. Otherwise, the longest suffix of *reply* that is a
    proper prefix of some stop string is removed. For example, a trailing
    ``"\\nYo"`` is withheld while waiting to see whether ``"\\nYou:"``
    completes.

    Args:
        reply (str): The reply to apply stopping strings to.
        stop_strings (List[str]): Stop strings to check for; empty strings are
            ignored (they would otherwise match at position 0).

    Returns:
        Tuple[str, bool]: The possibly shortened reply, and whether a complete
        stop string was found.
    """
    stops = [s for s in (stop_strings or []) if s]

    positions = [idx for idx in (reply.find(s) for s in stops) if idx != -1]
    if positions:
        return reply[: min(positions)], True

    hold_back = 0
    for s in stops:
        # Try longer partial matches first; only lengths beating hold_back matter.
        for j in range(min(len(s) - 1, len(reply)), hold_back, -1):
            if reply.endswith(s[:j]):
                hold_back = j
                break

    return (reply[:-hold_back] if hold_back else reply), False


def _get_args():
    parser = ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000, help="Demo server port.")

    # model related
    parser.add_argument("--model-path", type=str, default="01-ai/Yi-VL-34B")
    parser.add_argument("--model-name", type=str, default="yi-vl")

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = _get_args()

    # ``engine`` is the module-level global read by the request handler. It is
    # only defined when this file runs as a script, so serving ``app`` through
    # an external ``uvicorn module:app`` command would leave it undefined.
    engine = create_generate_model(args)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
